from enum import Enum
from functools import partial

import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel

from . import (
    default_hooks as default,
    powerSGD_hook as powerSGD,
    quantization_hooks as quantization,
)


def _ddp_comm_hook_wrapper(comm_hook, model, state):
    """
    Attach ``comm_hook`` to ``model`` by calling
    ``model.register_comm_hook(state, comm_hook)``.

    Note the argument order: this wrapper receives the hook first, while
    ``register_comm_hook`` is called with the state first.
    """
    register = model.register_comm_hook
    register(state, comm_hook)


def _powerSGD_comm_hook_wrapper(
    comm_hook,
    model,
    state,
    matrix_approximation_rank,
    use_error_feedback=True,
    random_seed=0,
):
    """
    Register a PowerSGD ``comm_hook`` on ``model``.

    To stay consistent with the wrappers of the other DDP comm hooks, the
    ``state`` argument only needs to be a process group. It is wrapped,
    together with ``matrix_approximation_rank``, ``use_error_feedback`` and
    ``random_seed``, into a ``powerSGD.PowerSGDState`` instance, and that
    instance is what gets passed to ``model.register_comm_hook``.
    """
    state_kwargs = {
        "process_group": state,
        "matrix_approximation_rank": matrix_approximation_rank,
        "use_error_feedback": use_error_feedback,
        "random_seed": random_seed,
    }
    # Build the full PowerSGD state first, then hand it to the model.
    hook_state = powerSGD.PowerSGDState(**state_kwargs)
    model.register_comm_hook(hook_state, comm_hook)


class DDPCommHookType(Enum):
    """
    Enumeration of the hooks in ``torch.distributed.algorithms.ddp_comm_hooks``.

    Every member's value is a ``functools.partial`` of ``_ddp_comm_hook_wrapper``
    or ``_powerSGD_comm_hook_wrapper`` with the hook already bound (the PowerSGD
    variants also bind ``matrix_approximation_rank``). Calling the value with a
    model and a state performs the registration. As an example, you can register
    the allreduce hook by
    ``DDPCommHookType.ALLREDUCE.value(model=model, state=process_group)``.
    """

    # NOTE(review): the member values are functools.partial objects. How the
    # enum machinery treats partials may differ between Python versions --
    # confirm against the enum documentation for the supported interpreters.

    # Hooks registered through the generic wrapper.
    ALLREDUCE = partial(
        _ddp_comm_hook_wrapper,
        comm_hook=default.allreduce_hook,
    )
    FP16_COMPRESS = partial(
        _ddp_comm_hook_wrapper,
        comm_hook=default.fp16_compress_hook,
    )
    QUANTIZE_PER_TENSOR = partial(
        _ddp_comm_hook_wrapper,
        comm_hook=quantization.quantization_pertensor_hook,
    )
    QUANTIZE_PER_CHANNEL = partial(
        _ddp_comm_hook_wrapper,
        comm_hook=quantization.quantization_perchannel_hook,
    )

    # PowerSGD hooks: the default variant uses a rank-1 approximation.
    POWER_SGD = partial(
        _powerSGD_comm_hook_wrapper,
        comm_hook=powerSGD.powerSGD_hook,
        matrix_approximation_rank=1,
    )
    # Rank-2 PowerSGD can give a higher accuracy than the default rank-1
    # version, but it runs slower and consumes more memory.
    POWER_SGD_RANK2 = partial(
        _powerSGD_comm_hook_wrapper,
        comm_hook=powerSGD.powerSGD_hook,
        matrix_approximation_rank=2,
    )

    # Batching can lead to faster training at the cost of accuracy.
    BATCHED_POWER_SGD = partial(
        _powerSGD_comm_hook_wrapper,
        comm_hook=powerSGD.batched_powerSGD_hook,
        matrix_approximation_rank=1,
    )
    BATCHED_POWER_SGD_RANK2 = partial(
        _powerSGD_comm_hook_wrapper,
        comm_hook=powerSGD.batched_powerSGD_hook,
        matrix_approximation_rank=2,
    )


def register_ddp_comm_hook(
    comm_hook_type: DDPCommHookType, model: DistributedDataParallel, state=None
):
    """
    Register one of the ``torch.distributed.algorithms.ddp_comm_hooks`` hooks
    on a DDP model.

    The hook is selected through ``comm_hook_type``, a member of the
    ``DDPCommHookType`` enum. The member's value is called with ``model`` and
    ``state`` as keyword arguments; ``state`` defaults to ``None`` and is
    forwarded unchanged.
    Uses Python comm hook implementations.

    Example::
        >>> register_ddp_comm_hook(DDPCommHookType.FP16_COMPRESS, model, state)
    """
    # The enum value is the pre-bound wrapper that performs the registration.
    bound_wrapper = comm_hook_type.value
    bound_wrapper(model=model, state=state)
